# Copyright 2021 Wechat Group, Tencent
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import numpy as np
import test


class LayerNormalization(test.Test):
    """Random test data for layer normalization.

    Each case normalizes a random input over a trailing block of axes
    (``axis`` through the last axis). It then applies an elementwise scale
    (``gamma``) and shift (``beta``) shaped like that trailing block.
    """

    @property
    def classPrefix(self):
        """Return the prefix string associated with this test class."""
        return "layer_normalization"

    def getRandomTestData(self):
        """Generate one random layer-normalization case with its reference output.

        Returns:
            A dict of numpy arrays:
              "axis": int32 array of length 1 holding the first normalized axis;
              "a": float64 input of rank 1-4, every dimension in [1, 9];
              "LayerNorm/gamma", "LayerNorm/beta": float64 arrays shaped
                  like ``a.shape[axis:]``;
              "expect": the reference result, same shape as "a".
        """
        # Keep the draw order unchanged: reordering these calls would change
        # the cases produced for a given random seed.
        rank = np.random.randint(1, 5)
        dims_of_a = np.random.randint(1, 10, rank)
        first_axis = np.random.randint(0, rank)
        data = np.random.uniform(-1.0, 1.0, dims_of_a)
        scale = np.random.uniform(-1.0, 1.0, dims_of_a[first_axis:])
        shift = np.random.uniform(-1.0, 1.0, dims_of_a[first_axis:])

        # Statistics span axes [first_axis, rank). keepdims leaves size-1
        # axes in place so the results broadcast back against `data`.
        norm_axes = tuple(range(first_axis, rank))
        count = np.multiply.reduce(dims_of_a[first_axis:])

        mean = np.add.reduce(data, axis=norm_axes, keepdims=True) / count
        deviation = data - mean
        variance = (
            np.add.reduce(deviation * deviation, axis=norm_axes, keepdims=True)
            / count
        )

        # The epsilon keeps the square root away from zero. Variance is exactly
        # 0 when every normalized dimension has size 1.
        inv_std = 1.0 / np.sqrt(variance + 1e-12)
        factor = inv_std * scale
        expect = data * factor + shift - mean * factor

        axis_tensor = np.asarray([first_axis], dtype=np.int32)
        return {
            "axis": axis_tensor,
            "a": data,
            "LayerNorm/gamma": scale,
            "LayerNorm/beta": shift,
            "expect": expect,
        }
